"""Plot feature data."""
import matplotlib.pyplot as plt
import pandas as pd
import json
from emodel_generalisation.mcmc import get_mean_sd
import seaborn as sns

from emodel_generalisation.mcmc import load_chains
from emodel_generalisation.mcmc import plot_corner
from emodel_generalisation.mcmc import get_2d_correlations


def plot_corr(corr_df, vlim=0.8):
    """Plot a correlation matrix as a heatmap on a new figure.

    Args:
        corr_df (pandas.DataFrame): correlation values; its index and columns are used as
            y and x tick labels.
        vlim (float): half-width of the symmetric colour range ``[-vlim, vlim]``.

    Returns:
        matplotlib.axes.Axes: the axes holding the heatmap. The new 6x4 figure is left as
        the current figure so callers can ``plt.savefig`` it.
    """
    plt.figure(figsize=(6, 4))
    ax = plt.gca()

    sns.heatmap(
        data=corr_df,
        ax=ax,
        # vmin must not exceed vmax (matplotlib's Normalize rejects an inverted range);
        # the scale is symmetric around 0 so positive/negative correlations map evenly.
        vmin=-vlim,
        vmax=vlim,
        cmap="coolwarm",
        linewidths=0.5,
        linecolor="k",
        cbar_kws={"label": "pearson", "shrink": 0.3},
        xticklabels=True,
        yticklabels=True,
        square=True,
    )
    plt.tight_layout()
    return ax


if __name__ == "__main__":
    # Every efeature used below belongs to this protocol/location/variable.
    prefix = "Step_ReboundBurst_burst.soma.v."

    mcmc_df = load_chains("../../mcmc_run/run_df.csv", base_path="../../mcmc_run/")
    mcmc_df = mcmc_df[mcmc_df.cost < 2.0]
    feat_df = mcmc_df["features"]

    # Three groups of models selected from their burst features
    # (plotted below as C0: ecel, C1: spp, C2: runaway).
    mask_runaway = (feat_df[f"{prefix}burst_runaway"] < 0.05) & (
        feat_df[f"{prefix}time_to_last_spike"] > 20000
    )
    mask_ecel = (feat_df[f"{prefix}all_burst_number"] < 5) & (
        feat_df[f"{prefix}tonic_after_burst"] == 0.0
    )
    # NOTE(review): models with exactly 5 bursts fall in neither ecel nor spp; confirm intended.
    mask_spp = (
        ~mask_runaway
        & (feat_df[f"{prefix}all_burst_number"] < 10)
        & (feat_df[f"{prefix}all_burst_number"] > 5)
    )

    # Corner plot of a subset of normalized parameters, highlighting hand-picked models.
    print(mcmc_df["normalized_parameters"])
    corner_params = [
        "g_pas.all",
        "gnabar_hh2.somatic",
        "gkbar_hh2.somatic",
        "gcabar_it2.basal",
        "constant.distribution_increase",
        "shift_it2.somadend",
        "gbar_ican.basal",
        "gkbar_iahp.basal",
    ]
    _mcmc_df = mcmc_df[[("normalized_parameters", p) for p in corner_params]]
    print(_mcmc_df["normalized_parameters"].columns)

    # Collect the highlighted models and concatenate once (no empty seed frame).
    highlight_dfs = []
    for color, model in zip(["C0", "C1", "C2"], ["ecel", "spp", "runaway"]):
        _df = pd.read_csv(f"params_{model}.csv", header=[0, 1])
        _df["color"] = color
        highlight_dfs.append(_df)
    models_df = pd.concat(highlight_dfs)
    print(models_df)
    plot_corner(
        _mcmc_df,
        sort_params=False,
        filename="corner.pdf",
        highlights=models_df,
        cmap="Greys",
        normalize=True,
        with_pearson=True,
    )

    # Features analysed below; short names are the last dotted component of the full names.
    short_features = [
        "all_burst_number",
        "spikes_per_burst",
        "burst_mean_freq",
        "peak_voltage",
        "inv_first_ISI",
        "AP2_AP1_peak_diff",
        "AHP_depth_abs",
        "time_to_first_spike",
        "postburst_min_values",
    ]
    full_features = [prefix + name for name in short_features]

    with open("../../mcmc_run/config/features_prots.json", encoding="utf-8") as f:
        efeatures = json.load(f)

    # One horizontal-histogram panel per feature, one step curve per model group.
    groups = [(mask_ecel, "C0"), (mask_spp, "C1"), (mask_runaway, "C2")]
    fig, axs = plt.subplots(1, len(full_features), figsize=(1.5 * len(full_features), 3))
    for ax, feat in zip(axs, full_features):
        f_mean, f_std = get_mean_sd(efeatures, feat)
        d = feat_df[feat]
        # Shared range so the bins of the three groups line up.
        hist_range = (d.min(), d.max())
        for mask, color in groups:
            ax.hist(
                d[mask],
                bins=30,
                orientation="horizontal",
                color=color,
                histtype="step",
                range=hist_range,
            )
        ax.spines.right.set_visible(False)
        ax.spines.top.set_visible(False)
        ax.set_ylabel(feat.split(".")[-1])
        ax.set_xlabel("# models")
        if not feat.endswith("number"):
            # Model mean/1sd (blue) vs experimental mean/1sd/2sd (black); lower bounds are
            # only drawn when positive.
            # NOTE(review): the labels are never displayed, as no legend is drawn.
            d_mean, d_std = d.mean(), d.std()
            ax.axhline(d_mean, c="b", label="model mean")
            ax.axhline(d_mean + d_std, c="b", ls="--", label="model 1sd")
            if d_mean - d_std > 0:
                ax.axhline(d_mean - d_std, c="b", ls="--")
            ax.axhline(f_mean, c="k", label="exp. mean")
            if f_mean - f_std > 0:
                ax.axhline(f_mean - f_std, c="k", ls="--")
            ax.axhline(f_mean + f_std, c="k", ls="--", label="exp. 1sd")
            if f_mean - 2 * f_std > 0:
                ax.axhline(f_mean - 2 * f_std, c="k", ls="-.")
            ax.axhline(f_mean + 2 * f_std, c="k", ls="-.", label="exp. 2sd")
    fig.tight_layout()
    fig.savefig("feature_hist.pdf")

    # Selected features with short column names, used for feature-vs-feature scatters.
    data_df = feat_df[full_features].reset_index(drop=True)
    data_df.columns = short_features

    # Restrict the "features" group to the selected features (exact level-0 match, not a
    # substring test) so the correlation matrices only cover them.
    mcmc_df = mcmc_df.drop(
        columns=[c for c in mcmc_df.columns if c[0] == "features" and c[1] not in full_features]
    )
    corr_df = get_2d_correlations(mcmc_df, x_col="features", y_col="features", tpe="pearson")
    corr_df.columns = [name.split(".")[-1] for name in corr_df.columns]
    corr_df.index = [name.split(".")[-1] for name in corr_df.index]
    plot_corr(corr_df)
    plt.savefig("2d_corr.pdf")

    corr_df = get_2d_correlations(
        mcmc_df, x_col="features", y_col="normalized_parameters", tpe="pearson"
    )
    corr_df.index = [name.split(".")[-1] for name in corr_df.index]
    corr_df = corr_df[
        [
            "g_pas.all",
            "shift_it2.somadend",
            "gnabar_hh2.somatic",
            "gkbar_hh2.somatic",
            "gbar_ican.basal",
            "gkbar_iahp.basal",
            "gcabar_it2.basal",
        ]
    ]
    plot_corr(corr_df)
    plt.savefig("2d_corr_params.pdf")

    # Individual scatter plots: (x data, y data, x label, y label, colour, output file).
    params_df = mcmc_df["parameters"]
    scatter_specs = [
        (
            params_df["gcabar_it2.basal"],
            feat_df[f"{prefix}inv_first_ISI"],
            "gcabar_it2.basal",
            "inv. first ISI",
            "C5",
            "corr_params_1.pdf",
        ),
        (
            params_df["gkbar_iahp.basal"],
            feat_df[f"{prefix}spikes_per_burst"],
            "gkbar_iahp.basal",
            "spikes_per_burst",
            "k",
            "corr_params_2.pdf",
        ),
        (
            params_df["shift_it2.somadend"],
            feat_df[f"{prefix}time_to_first_spike"],
            "shift_it2.somadend",
            "time_to_first_spike",
            "C4",
            "corr_params_3.pdf",
        ),
        (
            data_df["burst_mean_freq"],
            data_df["inv_first_ISI"],
            "burst mean freq",
            "inv. first ISI",
            "k",
            "corr_1.pdf",
        ),
        (
            data_df["burst_mean_freq"],
            data_df["spikes_per_burst"],
            "burst mean freq",
            "spikes per burst",
            "k",
            "corr_2.pdf",
        ),
        (
            data_df["time_to_first_spike"],
            data_df["postburst_min_values"],
            "time to first spike",
            "postburst min values",
            "k",
            "corr_3.pdf",
        ),
        (
            data_df["AHP_depth_abs"],
            data_df["inv_first_ISI"],
            "AHP depth abs",
            "inv. first ISI",
            "k",
            "corr_4.pdf",
        ),
        (
            params_df["gkbar_iahp.basal"],
            feat_df[f"{prefix}spikes_per_burst"],
            "iahp basal",
            "spikes per burst",
            "k",
            "corr_5.pdf",
        ),
    ]
    for dx, dy, xlabel, ylabel, color, filename in scatter_specs:
        plt.figure(figsize=(3, 3))
        plt.scatter(dx, dy, marker=".", s=1, c=color, rasterized=True)
        plt.ylabel(ylabel)
        plt.xlabel(xlabel)
        plt.tight_layout()
        plt.savefig(filename)
